import json
import math
import pickle

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

import re
import tokenize
from io import StringIO
from tqdm import tqdm
from collections import Counter
# Matches, in priority order: a // line comment, a /* ... */ block comment, a
# single-quoted literal, or a double-quoted literal. Literals are matched too so
# that comment markers appearing *inside* a literal are not treated as comments.
_C_STYLE_COMMENT_OR_STRING = re.compile(
    r'//.*?$|/\*.*?\*/|\'(?:\\.|[^\\\'])*\'|"(?:\\.|[^\\"])*"',
    re.DOTALL | re.MULTILINE,
)


def _drop_blank_lines(text):
    """Return ``text`` with every empty or whitespace-only line removed."""
    return "\n".join(line for line in text.split("\n") if line.strip() != "")


def _strip_python_comments_and_docstrings(source):
    """Remove comments and docstring-like string statements from Python code.

    Tokens are re-emitted with their original column spacing. A STRING token
    is dropped when it directly follows an INDENT or NEWLINE token (i.e. it
    starts a statement or a block) or when it starts in column 0. Every other
    STRING token is kept. Blank lines are removed from the result.

    NOTE: because of the column-0 rule, a string literal that starts a
    continuation line at column 0 is dropped even though it is part of an
    expression. On Python 3.12+ f-strings tokenize as FSTRING_* rather than
    STRING, so they are never removed.

    Errors from ``tokenize`` (e.g. ``tokenize.TokenError`` or
    ``IndentationError`` for malformed input) propagate to the caller.
    """
    pieces = []
    # Start as if we had just seen an INDENT so that a string that is the very
    # first token (a module docstring) is dropped.
    prev_toktype = tokenize.INDENT
    last_lineno = -1
    last_col = 0
    tokens = tokenize.generate_tokens(StringIO(source).readline)
    for token_type, token_string, start, end, _line in tokens:
        start_line, start_col = start
        end_line, end_col = end
        if start_line > last_lineno:
            last_col = 0
        # Re-create the horizontal gap between the previous token and this one.
        if start_col > last_col:
            pieces.append(" " * (start_col - last_col))
        if token_type == tokenize.COMMENT:
            pass  # comments are simply not emitted
        elif token_type == tokenize.STRING:
            starts_statement = prev_toktype in (tokenize.INDENT, tokenize.NEWLINE)
            if not starts_statement and start_col > 0:
                pieces.append(token_string)
        else:
            pieces.append(token_string)
        prev_toktype = token_type
        last_col = end_col
        last_lineno = end_line
    return _drop_blank_lines("".join(pieces))


def _strip_c_style_comments(source):
    """Replace // and /* */ comments with a space, keep string literals, drop blank lines."""

    def keep_literals(match):
        text = match.group(0)
        # Comments begin with '/', quoted literals never do. A comment becomes a
        # single space (not "") so the tokens on either side do not fuse.
        return " " if text.startswith("/") else text

    return _drop_blank_lines(_C_STYLE_COMMENT_OR_STRING.sub(keep_literals, source))


def remove_comments_and_docstrings(source, lang):
    """Return ``source`` with comments (and, for Python, docstrings) removed.

    Args:
        source: program text to clean.
        lang: language selector.
            * ``"python"``: comments and docstring-like string statements
              are removed via ``tokenize``, then blank lines are dropped.
            * ``"ruby"``: ``source`` is returned unchanged.
            * anything else: treated as C-style source. ``//`` and
              ``/* */`` comments are replaced by a space, quoted literals
              are preserved, and blank lines are dropped.

    Returns:
        The cleaned source text.

    Raises:
        For ``"python"``, errors raised by ``tokenize`` on malformed source.
    """
    if lang == "python":
        return _strip_python_comments_and_docstrings(source)
    if lang == "ruby":
        # Ruby comments are not handled; the text is passed through untouched.
        return source
    return _strip_c_style_comments(source)

def remove_comments(code):
    """Strip comments and docstrings from a Python snippet, best-effort.

    Args:
        code: Python source text.

    Returns:
        The cleaned source, or the sentinel string ``'ERROR'`` if processing
        raised an exception (e.g. the snippet does not tokenize). The script
        compares against this sentinel to separate unprocessable rows.
    """
    try:
        return remove_comments_and_docstrings(code, 'python')
    except Exception:
        # Deliberately best-effort: one bad snippet is tagged instead of aborting
        # the whole run. Catch Exception (not a bare except) so Ctrl-C and
        # SystemExit still propagate.
        return 'ERROR'

def cal_tools(key):
    """Return how many action names the global ``stat`` table holds for ``key``.

    ``stat`` is the object unpickled from ``statistics.pkl`` at module level.
    ``stat[key]['action_names']`` must support ``len()``.
    """
    entry = stat[key]
    action_names = entry['action_names']
    return len(action_names)

# Load data
with open('./query_and_description_2.json', 'r', encoding='utf-8') as fp:
    data = json.load(fp)

# NOTE(review): pickle.load can execute arbitrary code stored in the file;
# only load a statistics file produced by a trusted pipeline.
with open('../../data/statistics.pkl', 'rb') as fp:
    stat = pickle.load(fp)

# Keep only samples that have an entry in the statistics table.
data = [sample for sample in data if sample['key'] in stat]

data = pd.DataFrame(data)
# Drop rows whose description / line-by-line code is empty, whitespace-only or
# missing. fillna('') makes null values count as empty; otherwise .str.strip()
# yields NaN for them and the row silently survives the filter.
data = data[~data['description'].fillna('').str.strip().eq('')]
data = data[~data['line_by_line'].fillna('').str.strip().eq('')]

# for debug only
# data = data.sample(500)

# data['num_tools'] = data['key'].map(lambda x: math.ceil(np.log2(cal_tools(x) + 1)))
data['code'] = data['line_by_line'].map(remove_comments)

# Bucket each snippet by ceil(log2(number of lines)): 1 line -> 0, 2 -> 1,
# 3-4 -> 2, 5-8 -> 3, ... The bucket is the stratification label below.
data['num_code_lines'] = data['code'].map(lambda x: math.ceil(np.log2(x.count('\n') + 1)))

# data['stratify_feature'] = data['num_tools'].astype(str) + "_" + data['num_code_lines'].astype(str) + "_" + data['num_conditions'].astype(str) + "_" + data['num_loops'].astype(str)
data['stratify_feature'] = data['num_code_lines'].astype(str)

# Rows whose code could not be processed are not stratified; they all go
# straight into the training set.
train_data = data[data['code'] == 'ERROR']

remaining_data = data[data['code'] != 'ERROR']

print(Counter(remaining_data['stratify_feature'].values))

print(remaining_data[remaining_data['num_code_lines'] == 0])

# Stratified splits of remaining_data. Note that train_test_split raises
# ValueError if any stratum has fewer than 2 members at a given step.
# Step 1: 20% of remaining_data -> reward set (80% left).
train_rest, reward = train_test_split(
    remaining_data,
    test_size=0.2,
    stratify=remaining_data['stratify_feature'],
    random_state=42
)

# Step 2: 0.1 / 0.8 of the remaining 80% = 10% of remaining_data -> test set.
train_rest, test = train_test_split(
    train_rest,
    test_size=0.1 / 0.8,
    stratify=train_rest['stratify_feature'],
    random_state=42
)

# Step 3: 0.1 / 0.7 of the remaining 70% = 10% of remaining_data -> dev set;
# the last 60% is kept for training.
train_rest, dev = train_test_split(
    train_rest,
    test_size=0.1 / 0.7,
    stratify=train_rest['stratify_feature'],
    random_state=42
)

# Combine all training data (unprocessable rows + stratified remainder).
train_data = pd.concat([train_data, train_rest])

# Extract unique keys for each set and store them in a dictionary.
# NOTE(review): rows (not keys) are split, so if one key can occur in several
# rows it may appear in more than one set - confirm keys are unique per row.
dataset_keys = {
    'train': train_data['key'].unique().tolist(),
    'reward': reward['key'].unique().tolist(),
    'dev': dev['key'].unique().tolist(),
    'test': test['key'].unique().tolist(),
}

# Summarize how many unique keys ended up in each set.
print("Keys stored in different sets:")
for split_name, split_keys in dataset_keys.items():
    print(f"  {split_name}: {len(split_keys)} unique keys")

# Save the key lists so the same split can be reused later.
with open('dataset_split_keys.json', 'w', encoding='utf-8') as fp:
    json.dump(dataset_keys, fp, indent=4)

# Check dataset sizes
print(f"Training set size: {len(train_data)}")
print(f"Reward set size: {len(reward)}")
print(f"Dev set size: {len(dev)}")
print(f"Test set size: {len(test)}")